archived/rl_gamerserver_ray/common/sagemaker_rl/coach_launcher.py (230 lines of code) (raw):
import argparse
import copy
import glob
import logging
import os
import re
import shutil
import sys
import rl_coach.core_types
from rl_coach import logger
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.agents.policy_gradients_agent import PolicyGradientsAgentParameters
from rl_coach.base_parameters import Frameworks, TaskParameters, VisualizationParameters
from rl_coach.coach import CoachLauncher
from rl_coach.core_types import MaxDumpFilter, RunPhase, SelectedPhaseOnlyDumpFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.logger import screen
from rl_coach.utils import short_dynamic_import
from .configuration_list import ConfigurationList
screen.set_use_colors(False)  # Disable ANSI colors: plain text renders cleanly in CloudWatch logs
class CoachConfigurationList(ConfigurationList):
    """Helper Object for converting CLI arguments (or SageMaker hyperparameters)
    into Coach configuration.
    """

    # Security: hyperparameter values may only be instantiated from this explicit
    # allow-list of rl_coach.core_types classes — never from an arbitrary string
    # the customer passes in.
    ALLOWED_TYPES = {
        type_name: getattr(rl_coach.core_types, type_name)
        for type_name in (
            "Frames",
            "EnvironmentSteps",
            "EnvironmentEpisodes",
            "TrainingSteps",
            "Time",
        )
    }
class SageMakerCoachPresetLauncher(CoachLauncher):
    """Base class for training RL tasks using RL-Coach.

    Customers subclass this to define specific kinds of workloads, overriding
    these methods as needed.
    """

    def __init__(self):
        super().__init__()
        # Populated with a CoachConfigurationList in get_config_args(); None
        # until CLI/hyperparameter parsing has happened.
        self.hyperparameters = None
        # Deprecated alias kept for backward compatibility with any subclass
        # that referenced the old attribute name.
        self.hyperparams = None

    def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
        """Overrides the default CLI parsing.

        Sets the configuration parameters for what a SageMaker run should do.
        Note, this does not support the "play" mode.

        :param parser: Coach's own argument parser; used only to obtain a
            Namespace pre-filled with Coach's defaults.
        :return: Namespace with SageMaker-specific overrides applied.
        :raises ValueError: on an odd number of extra CLI args, or an extra
            arg that does not start with ``--``.
        """
        # First, convert the parser to a Namespace object with all default values.
        empty_arg_list = []
        args, _ = parser.parse_known_args(args=empty_arg_list)
        parser = self.sagemaker_argparser()
        sage_args, unknown = parser.parse_known_args()

        # Now fill in the args that we care about.
        sagemaker_job_name = os.environ.get("sagemaker_job_name", "sagemaker-experiment")
        args.experiment_name = logger.get_experiment_name(sagemaker_job_name)

        # Override experiment_path used for outputs.
        args.experiment_path = "/opt/ml/output/intermediate"
        rl_coach.logger.experiment_path = "/opt/ml/output/intermediate"  # for gifs

        args.checkpoint_save_dir = "/opt/ml/output/data/checkpoint"
        args.checkpoint_save_secs = 10  # should avoid hardcoding

        # onnx export is for mxnet deployment only (not tensorflow).
        save_model = sage_args.save_model == 1
        backend = os.getenv("COACH_BACKEND", "tensorflow")
        if save_model and backend == "mxnet":
            args.export_onnx_graph = True
        args.no_summary = True
        args.num_workers = sage_args.num_workers
        args.framework = Frameworks[backend]
        args.preset = sage_args.RLCOACH_PRESET
        # args.apply_stop_condition = True  # uncomment for old coach behaviour

        # Remaining unknown CLI args are treated as "--name value" hyperparameter
        # pairs and stored for later application to the graph manager.
        self.hyperparameters = CoachConfigurationList()
        if len(unknown) % 2 == 1:
            raise ValueError("Odd number of command-line arguments specified. Key without value.")
        for i in range(0, len(unknown), 2):
            name = unknown[i]
            if name.startswith("--"):
                name = name[2:]
            else:
                raise ValueError("Unknown command-line argument %s" % name)
            val = unknown[i + 1]
            self.map_hyperparameter(name, val)
        return args

    def map_hyperparameter(self, name, value):
        """This is a good method to override where customers can specify custom shortcuts
        for hyperparameters. Default takes everything starting with "rl." and sends it
        straight to the graph manager.

        :raises ValueError: if the name does not start with the "rl." prefix.
        """
        if name.startswith("rl."):
            self.apply_hyperparameter(name, value)
        else:
            raise ValueError("Unknown hyperparameter %s" % name)

    def apply_hyperparameter(self, name, value):
        """Save this hyperparameter to be applied to the graph_manager object when
        it's ready.
        """
        print("Applying RL hyperparameter %s=%s" % (name, value))
        self.hyperparameters.store(name, value)

    def default_preset_name(self):
        """
        Sub-classes will typically return a single hard-coded string.

        :raises NotImplementedError: when neither this method is overridden nor
            the deprecated ``DEFAULT_PRESET`` class attribute is set.
        """
        # TODO: remove this fallback after converting all samples.
        default_preset = getattr(self, "DEFAULT_PRESET", None)
        if default_preset is not None:
            screen.warning(
                "Deprecated configuration of default preset. Please implement default_preset_name()"
            )
            return default_preset
        raise NotImplementedError(
            "Sub-classes must specify the name of the default preset "
            + "for this RL problem. This will be the name of a python "
            + "file (without .py) that defines a graph_manager variable"
        )

    def sagemaker_argparser(self) -> argparse.ArgumentParser:
        """
        Expose only the CLI arguments that make sense in the SageMaker context.
        """
        parser = argparse.ArgumentParser()
        # Arguably this would be cleaner if we copied the config from the base class argparser.
        parser.add_argument(
            "-n",
            "--num_workers",
            help="(int) Number of workers for multi-process based agents, e.g. A3C",
            default=1,
            type=int,
        )
        parser.add_argument(
            "-p",
            "--RLCOACH_PRESET",
            help="(string) Name of the file with the RLCoach preset",
            default=self.default_preset_name(),
            type=str,
        )
        parser.add_argument(
            "--save_model",
            help="(int) Flag to save model artifact after training finish",
            default=0,
            type=int,
        )
        return parser

    def path_of_main_launcher(self):
        """
        A bit of python magic to find the path of the file that launched the current process.
        """
        try:
            launcher_file = os.path.abspath(sys.modules["__main__"].__file__)
            return os.path.dirname(launcher_file)
        except AttributeError:
            # If __main__.__file__ is missing, then we're probably in an interactive python shell
            return os.getcwd()

    def preset_from_name(self, preset_name):
        """Dynamically import the ``graph_manager`` variable from the preset file
        that lives next to the launching script.
        """
        launcher_path = self.path_of_main_launcher()
        print("Loading preset %s from %s" % (preset_name, launcher_path))
        preset_path = os.path.join(launcher_path, preset_name) + ".py:graph_manager"
        graph_manager = short_dynamic_import(preset_path, ignore_module_case=True)
        return graph_manager

    def get_graph_manager_from_args(self, args):
        """Build the graph manager for the chosen preset and apply hyperparameter
        overrides and the selected framework to it.
        """
        # First get the graph manager for the customer-specified (or default) preset.
        graph_manager = self.preset_from_name(args.preset)
        # Now override whatever config is specified in hyperparameters.
        self.hyperparameters.apply_subset(graph_manager, "rl.")
        # Set framework.
        # Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params.
        if hasattr(graph_manager, "agent_params"):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = args.framework
        elif hasattr(graph_manager, "agents_params"):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = args.framework
        return graph_manager

    def _save_tf_model(self):
        """Export the latest TF checkpoint as a servable SavedModel under /opt/ml/model."""
        ckpt_dir = "/opt/ml/output/data/checkpoint"
        model_dir = "/opt/ml/model"
        import tensorflow as tf  # importing tensorflow here so that MXNet docker image is compatible with this file.

        # Re-Initialize from the checkpoint so that you will have the latest models up.
        tf.train.init_from_checkpoint(
            ckpt_dir, {"main_level/agent/online/network_0/": "main_level/agent/online/network_0"}
        )
        tf.train.init_from_checkpoint(
            ckpt_dir, {"main_level/agent/online/network_1/": "main_level/agent/online/network_1"}
        )
        # Create a new session with a new tf graph.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(tf.global_variables_initializer())  # initialize the checkpoint.
        # This is the node that will accept the input.
        input_nodes = tf.get_default_graph().get_tensor_by_name(
            "main_level/agent/main/online/" + "network_0/observation/observation:0"
        )
        # This is the node that will produce the output.
        output_nodes = tf.get_default_graph().get_operation_by_name(
            "main_level/agent/main/online/" + "network_1/ppo_head_0/policy"
        )
        # Save the model as a servable model.
        tf.saved_model.simple_save(
            session=sess,
            export_dir="model",
            inputs={"observation": input_nodes},
            outputs={"policy": output_nodes.outputs[0]},
        )
        # Move to the appropriate folder. Don't mind the directory, this just works.
        # rl-cart-pole is the name of the model. Remember it.
        shutil.move("model/", model_dir + "/model/tf-model/00000001/")
        # EASE will pick it up and upload to the right path.
        print("Success")

    def _save_onnx_model(self):
        """Copy the newest exported ONNX checkpoint into /opt/ml/model/model.onnx."""
        from .onnx_utils import fix_onnx_model

        ckpt_dir = "/opt/ml/output/data/checkpoint"
        model_dir = "/opt/ml/model"
        # Find the latest onnx file.
        # Currently done by name; expected to change in a future release of coach.
        glob_pattern = os.path.join(ckpt_dir, "*.onnx")
        onnx_files = list(glob.iglob(glob_pattern, recursive=True))
        if onnx_files:

            def extract_step(path):
                # File names look like ".../<step>_Step-....onnx"; sort numerically by step.
                return int(re.search(r"/(\d*)_Step.*", path, re.IGNORECASE).group(1))

            onnx_files.sort(key=extract_step)
            latest_onnx_file = onnx_files[-1]
            # Move to model directory.
            filepath_from = os.path.abspath(latest_onnx_file)
            filepath_to = os.path.join(model_dir, "model.onnx")
            shutil.move(filepath_from, filepath_to)
            fix_onnx_model(filepath_to)
        else:
            screen.warning("No ONNX files found in {}".format(ckpt_dir))

    @classmethod
    def train_main(cls):
        """Entrypoint for training.
        Parses command-line arguments and starts training.
        """
        trainer = cls()
        trainer.launch()
        # Create model artifact for model.tar.gz.
        parser = trainer.sagemaker_argparser()
        sage_args, unknown = parser.parse_known_args()
        if sage_args.save_model == 1:
            backend = os.getenv("COACH_BACKEND", "tensorflow")
            if backend == "tensorflow":
                trainer._save_tf_model()
            if backend == "mxnet":
                trainer._save_onnx_model()
class SageMakerCoachLauncher(SageMakerCoachPresetLauncher):
    """
    Older version of the launcher that doesn't use preset, but instead effectively has a single preset built in.
    """

    def __init__(self):
        super().__init__()
        screen.warning("DEPRECATION WARNING: Please switch to SageMakerCoachPresetLauncher")
        # TODO: Remove this whole class when nobody's using it any more.

    def define_environment(self):
        """Sub-classes must override this to return the environment parameters.

        :raises NotImplementedError: always, until overridden.
        """
        # Bug fix: the original `return NotImplementedEror(...)` misspelled the
        # exception name (a NameError at call time) and returned instead of raising.
        raise NotImplementedError(
            "Sub-class must define environment e.g. GymVectorEnvironment(level='your_module:YourClass')"
        )

    def get_graph_manager_from_args(self, args):
        """Returns the GraphManager object for coach to use to train by calling improve()"""
        # NOTE: TaskParameters are not configurable at this time.
        # Visualization
        vis_params = VisualizationParameters()
        self.config_visualization(vis_params)
        self.hyperparameters.apply_subset(vis_params, "vis_params.")
        # Schedule
        schedule_params = ScheduleParameters()
        self.config_schedule(schedule_params)
        self.hyperparameters.apply_subset(schedule_params, "schedule_params.")
        # Agent
        agent_params = self.define_agent()
        self.hyperparameters.apply_subset(agent_params, "agent_params.")
        # Environment
        env_params = self.define_environment()
        self.hyperparameters.apply_subset(env_params, "env_params.")
        graph_manager = BasicRLGraphManager(
            agent_params=agent_params,
            env_params=env_params,
            schedule_params=schedule_params,
            vis_params=vis_params,
        )
        return graph_manager

    def config_schedule(self, schedule_params):
        """Hook for sub-classes to mutate the ScheduleParameters in place; no-op by default."""
        pass

    def define_agent(self):
        """Sub-classes must override this to return an AgentParameters object.

        :raises NotImplementedError: always, until overridden.
        """
        raise NotImplementedError(
            "Subclass must create define_agent() method which returns an AgentParameters object. e.g.\n"
            "   return rl_coach.agents.dqn_agent.DQNAgentParameters()"
        )

    def config_visualization(self, vis_params):
        """Default visualization config: dump gifs of the best TEST-phase episodes."""
        vis_params.dump_gifs = True
        vis_params.video_dump_methods = [
            SelectedPhaseOnlyDumpFilter(RunPhase.TEST),
            MaxDumpFilter(),
        ]
        vis_params.print_networks_summary = True
        return vis_params